Truncated gamma #1187
def icdf(self, q):
    # https://github.com/pyro-ppl/numpyro/issues/969
    from numpyro.distributions.util import gammaincinv
I think you can move this import to the top.
@@ -411,3 +412,327 @@ def tree_flatten(self):
    @classmethod
    def tree_unflatten(cls, aux_data, params):
        return cls(batch_shape=aux_data)


def TruncatedGamma(base_gamma, low=None, high=None, validate_args=None):
I think it is better to expose the parameters of Gamma here (TruncatedGamma(concentration, rate, low=..., high=...)) rather than using a nested pattern. There are a couple of benefits to that:
- the parameters of the distribution are defined properly in arg_constraints
- it is easier to test
- there is no need for flatten/unflatten logic
    base_gamma = Gamma.tree_unflatten(base_aux, base_flatten)
    return cls(base_gamma, low=low)

@validate_sample
Unfortunately, the validate_sample logic currently does not work with cdf :(
# until jax/lax has direct implementation we'll need to rely on tfp
# https://github.com/pyro-ppl/numpyro/issues/969
try:
    import tensorflow_probability as tfpm
I think you can import tensorflow_probability.substrates.jax directly, to make sure that the jax backend is installed.
    return lprob - jnp.log(1.0 - lscale)

def _scale_moment(self, t):
    assert t > -self.base_gamma.concentration
This won't work for jax arrays (which might have abstract values under jit compilation). You can use jnp.where to mask out the invalid cases.
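The masking idea in this comment can be sketched roughly as follows. This is an illustration only, not the code from the PR: log_raw_moment is a hypothetical stand-in for _scale_moment, and it computes the log raw moment of an untruncated Gamma, which exists only when t > -concentration. Invalid inputs are replaced by a harmless value before the computation and reported as NaN afterwards, so nothing needs a concrete boolean under jit.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_raw_moment(t, concentration, rate):
    """Log of E[X^t] for an untruncated Gamma(concentration, rate).

    Hypothetical stand-in for the PR's _scale_moment; the moment only
    exists when t > -concentration.
    """
    valid = t > -concentration
    # Substitute a harmless value so gammaln never sees an invalid argument.
    safe_t = jnp.where(valid, t, 0.0)
    value = (
        gammaln(concentration + safe_t)
        - gammaln(concentration)
        - safe_t * jnp.log(rate)
    )
    # Report invalid inputs as NaN instead of asserting inside traced code.
    return jnp.where(valid, value, jnp.nan)


jitted = jax.jit(log_raw_moment)
print(jitted(1.0, 2.0, 1.0))   # log E[X] = log(concentration / rate)
print(jitted(-3.0, 2.0, 1.0))  # invalid input, masked to NaN
```

A Python assert on t would fail during tracing because the comparison has no concrete truth value; the two jnp.where calls keep the function traceable and avoid evaluating gammaln at an invalid point.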
def log_prob(self, value):
    lprob = self.base_gamma.log_prob(value)
    lscale = self.base_gamma.cdf(self.low)
    return lprob - jnp.log(1.0 - lscale)
You can use log1p(-lscale) for a better numerical result
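A small sketch of why log1p helps, using the standard library math module rather than jnp so that it runs without JAX. The value cdf_at_low is a made-up example of a very small base_gamma.cdf(low).

```python
import math

# Hypothetical value of base_gamma.cdf(low) for a lower bound close to zero.
cdf_at_low = 1e-20

# 1.0 - 1e-20 rounds to exactly 1.0 in double precision, so the correction is lost.
naive = math.log(1.0 - cdf_at_low)

# log1p works on the small argument directly and keeps the correction.
stable = math.log1p(-cdf_at_low)

print(naive, stable)
```

The same reasoning applies to jnp.log1p, and the loss is larger there because JAX defaults to 32-bit floats.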
@quattro It looks like the PR is in good shape; I just have the small comments above. Any chance we can have this in the next numpyro release?
Will try my best. Should have some time closer to the Thanksgiving holidays; does that fall before the next scheduled release?
Absolutely, there is no plan for the release date yet. Thank you!
Will we have this feature in the future?
PR for issue #969. Contains an initial implementation that performs uniform sampling + inverse CDF for the left-, right-, and doubly truncated Gamma. Relies on tensorflow functionality for the igammainv function, which is not yet implemented at the lax/jax level (see jax-ml/jax#5350).
There is a test that fails, but it is not clear to me if this is purely a numerical issue with the uniform + iCDF sampling, or a larger issue that I missed at the time I implemented things.
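The uniform sampling + inverse CDF approach from the description can be sketched as below. This is not the PR's implementation. To stay self-contained it uses the exponential distribution, which is a Gamma with concentration 1 and has a closed-form inverse CDF, in place of the igammainv-based Gamma case. All function names are hypothetical.

```python
import math
import random


def exp_cdf(x, rate):
    # CDF of Exponential(rate), i.e. Gamma(concentration=1, rate).
    return 1.0 - math.exp(-rate * x)


def exp_icdf(p, rate):
    # Inverse CDF of Exponential(rate); log1p keeps precision for small p.
    return -math.log1p(-p) / rate


def sample_truncated_exponential(rate, low=0.0, high=math.inf, rng=random):
    # Map the truncation bounds into probability space.
    p_low = exp_cdf(low, rate)
    p_high = exp_cdf(high, rate)
    # Draw uniformly between the two CDF values, then invert.
    u = rng.random()
    p = p_low + u * (p_high - p_low)
    # Keep p strictly below 1 so the inverse CDF stays finite.
    p = min(p, 1.0 - 2.0 ** -53)
    x = exp_icdf(p, rate)
    # Clamp to absorb rounding at the boundaries.
    return min(max(x, low), high)


rng = random.Random(0)
samples = [
    sample_truncated_exponential(2.0, low=0.5, high=1.5, rng=rng)
    for _ in range(1000)
]
print(min(samples), max(samples))
```

When both bounds sit far in a tail, p_high - p_low is a difference of nearly equal numbers, which is one place where numerical error can enter this scheme.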